pytorch - nn

Lecture 23

Dr. Colin Rundel

Odds & Ends

Torch models

Implementation details:

  • Models are implemented as a class inheriting from torch.nn.Module

  • Must implement constructor and forward() method

    • __init__() should call parent constructor via super()

      • Use torch.nn.Parameter() to indicate model parameters
    • forward() should implement the model - constants + parameters -> return predictions

Fitting procedure:

  • For each iteration of solver:

    • Get current predictions via a call to forward() or equivalent.

    • Calculate a (scalar) loss or equivalent

    • Call backward() method on loss

    • Use built-in optimizer (step() and then zero_grad() if necessary)

From last time

class Model(torch.nn.Module):
    """Linear regression (y ~ X @ beta) fit by gradient descent.

    Parameters
    ----------
    X : 2d tensor of predictors, shape (n, p)
    y : 1d tensor of responses, shape (n,)
    beta : optional initial coefficients, shape (p,); defaults to zeros
    """
    def __init__(self, X, y, beta=None):
        super().__init__()
        self.X = X
        self.y = y
        if beta is None:
            beta = torch.zeros(X.shape[1])
        # nn.Parameter already sets requires_grad=True, so the old
        # `beta.requires_grad = True` was redundant -- and it mutated the
        # caller's tensor in place (and raises for non-leaf inputs).
        self.beta = torch.nn.Parameter(beta)

    def forward(self, X):
        """Return the model predictions X @ beta."""
        return X @ self.beta

    def fit(self, opt, n=1000, loss_fn=torch.nn.MSELoss()):
        """Run `n` optimizer steps; return the per-iteration loss values.

        Parameters
        ----------
        opt : torch optimizer constructed over self.parameters()
        n : number of iterations
        loss_fn : loss callable (default MSE). The default instance is
            created once at definition time, which is harmless here
            because MSELoss is stateless.
        """
        losses = []
        for i in range(n):
            loss = loss_fn(
                self(self.X).squeeze(),
                self.y.squeeze()
            )
            loss.backward()
            opt.step()
            opt.zero_grad()
            # record the scalar loss value (pre-step, as before)
            losses.append(loss.item())

        return losses

What is self(self.X)?

This is (mostly) just shorthand for calling self.forward(X) to generate the output tensors from the current value(s) of the parameters.

This is done via the __call__() method in the torch.nn.Module class. __call__() allows Python classes to be invoked like functions.


class greet:
    """Callable object that prefixes a stored greeting onto a name."""

    def __init__(self, greeting):
        # remember the greeting for later calls
        self.greeting = greeting

    def __call__(self, name):
        # join greeting and name with a single space
        return " ".join((self.greeting, name))
hello = greet("Hello")
hello("Jane")
'Hello Jane'
gm = greet("Good morning")
gm("Bob")
'Good morning Bob'

MNIST & Logistic models

MNIST handwritten digits - simplified

from sklearn.datasets import load_digits
digits = load_digits()
X = digits.data
X.shape
(1797, 64)
X[0:2]
array([[ 0.,  0.,  5., 13.,  9.,  1.,  0.,
         0.,  0.,  0., 13., 15., 10., 15.,
         5.,  0.,  0.,  3., 15.,  2.,  0.,
        11.,  8.,  0.,  0.,  4., 12.,  0.,
         0.,  8.,  8.,  0.,  0.,  5.,  8.,
         0.,  0.,  9.,  8.,  0.,  0.,  4.,
        11.,  0.,  1., 12.,  7.,  0.,  0.,
         2., 14.,  5., 10., 12.,  0.,  0.,
         0.,  0.,  6., 13., 10.,  0.,  0.,
         0.],
       [ 0.,  0.,  0., 12., 13.,  5.,  0.,
         0.,  0.,  0.,  0., 11., 16.,  9.,
         0.,  0.,  0.,  0.,  3., 15., 16.,
         6.,  0.,  0.,  0.,  7., 15., 16.,
        16.,  2.,  0.,  0.,  0.,  0.,  1.,
        16., 16.,  3.,  0.,  0.,  0.,  0.,
         1., 16., 16.,  6.,  0.,  0.,  0.,
         0.,  1., 16., 16.,  6.,  0.,  0.,
         0.,  0.,  0., 11., 16., 10.,  0.,
         0.]])
y = digits.target
y.shape
(1797,)
y[0:10]
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

Example digits

Test train split

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, shuffle=True, random_state=1234
)
X_train.shape
(1437, 64)
y_train.shape
(1437,)
X_test.shape
(360, 64)
y_test.shape
(360,)
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
lr = LogisticRegression(
  penalty=None
).fit(
  X_train, y_train
)
accuracy_score(y_train, lr.predict(X_train))
1.0
accuracy_score(y_test, lr.predict(X_test))
0.9583333333333334

As Torch tensors

X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train)
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test)
X_train.shape
torch.Size([1437, 64])
y_train.shape
torch.Size([1437])
X_test.shape
torch.Size([360, 64])
y_test.shape
torch.Size([360])
X_train.dtype
torch.float32
y_train.dtype
torch.int64
X_test.dtype
torch.float32
y_test.dtype
torch.int64

PyTorch Model

class mnist_model(torch.nn.Module):
    """Multinomial logistic regression for the digits data, with the
    coefficient matrix and intercept held as explicit parameters."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # nn.Parameter sets requires_grad=True itself, so passing
        # requires_grad=True into randn was redundant.
        self.beta = torch.nn.Parameter(
            torch.randn(input_dim, output_dim)
        )
        self.intercept = torch.nn.Parameter(
            torch.randn(output_dim)
        )

    def forward(self, X):
        # Class scores (logits); CrossEntropyLoss applies softmax itself.
        return (X @ self.beta + self.intercept).squeeze()

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000):
        """Fit with SGD + momentum for `n` steps; return per-step losses.

        NOTE(review): X_test / y_test are accepted but unused here; they
        are kept so the signature matches the later models that also
        track test accuracy.
        """
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        # Hoisted out of the loop -- no need to build a new
        # CrossEntropyLoss module on every iteration.
        loss_fn = torch.nn.CrossEntropyLoss()
        losses = []

        for i in range(n):
            opt.zero_grad()
            loss = loss_fn(self(X_train), y_train)
            loss.backward()
            opt.step()

            losses.append(loss.item())

        return losses

Cross entropy loss

model = mnist_model(64, 10)
l = model.fit(X_train, y_train, X_test, y_test)

Cross entropy loss

From the pytorch documentation:

\[ \ell(x, y)=L=\left\{l_1, \ldots, l_N\right\}^{\top}, \quad l_n=-w_{y_n} \log \frac{\exp \left(x_{n, y_n}\right)}{\sum_{c=1}^C \exp \left(x_{n, c}\right)} \]

\[ \ell(x, y)= \begin{cases}\sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot 1\left\{y_n \neq \text { ignore_index }\right\}} l_n, & \text { if reduction }=\text { 'mean' } \\ \sum_{n=1}^N l_n, & \text { if reduction }=\text { 'sum' }\end{cases} \]

Out-of-sample accuracy

model(X_test)
tensor([[ -60.4698,  -60.3460,   -3.2359,
           -4.6123,  -30.8070,  -52.0707,
         -102.8619,   52.4684,  -24.2276,
          -27.4013],
        [ -35.2080,   48.2467,  -21.1850,
           16.8739,    3.5356,  -21.0426,
          -26.7105,  -27.8124,   10.2930,
           56.1511],
        [ -50.5956,  -15.3633,   17.9361,
            5.6485,  -40.0762,  -48.8120,
          -60.7111,   69.6762,   -1.5164,
            9.7559],
        [   1.1833,  -33.2534,   -0.3585,
          -36.4596,   -6.4512,  -19.3562,
           51.9370, -120.3680,  -10.1044,
          -43.0745],
        [  28.2609,  -19.5564,  -50.7594,
          -65.2076,  -12.3891,  -49.6241,
          -17.4184,  -60.7233,  -13.3392,
          -14.7305],
        [ -51.9549,   16.1311,   48.0923,
           11.2682,  -49.9447,  -15.4018,
          -19.4232,    1.9681,  -37.9048,
          -23.1772],
        [  -8.8474,  -13.6621,  -15.8109,
          -58.2890,   23.8521,  -17.9521,
           14.3023,  -49.3965,  -44.6686,
          -46.5255],
        [ -43.3213,  -51.2646,   -4.1250,
           36.7883, -107.4316,  -28.0748,
         -108.8582,  -66.3883,  -36.9246,
            1.4140],
        [ -24.8235,   -5.3153,  -23.6900,
          -31.3845,  -24.5766,  -25.6308,
           34.7935,  -98.9865,    3.6377,
          -18.1407],
        [ -39.9839,    2.4435,    3.9232,
           30.5909,  -56.2090,  -41.5259,
          -55.5499,  -36.6804,   -4.0164,
            2.5643],
        [ -48.5294,  -67.4332,   18.1071,
           -6.4645,    7.2488,  -47.2639,
          -60.9697,   10.8668,  -32.8081,
          -64.0977],
        [ -75.4684,  -20.6770,  -19.9017,
            0.9568,  -45.6811,  -28.1770,
          -25.0667,  -76.0366,   53.4137,
          -10.5690],
        [ -44.0514,  -28.9370,  -22.1685,
          -14.8544,  -23.6289,   -8.1833,
          -53.6284,    0.5759,  -38.3503,
          -55.2882],
        [  -5.7789,  -34.3641,  -31.3128,
          -30.3968,  -40.7891,  -47.3578,
          -64.9798,  -50.8311,   -1.9498,
           15.5228],
        [  -1.1835,   33.8432,  -33.9127,
            6.7529,   58.3464,  -32.2342,
           19.7274,  -62.9129,  -15.5001,
          -10.1331],
        [ -30.3365,  -64.1364,  -18.1107,
           21.4814,  -74.7543,  -49.6226,
         -101.4309,  -36.1643,   -6.0478,
          -10.9746],
        [ -13.8172,    5.8836,   -7.6360,
          -11.2843,   -6.7821,  -65.8260,
           -4.0350,  -16.0150,    8.6830,
          -18.2703],
        [ -53.7174,  -54.1730,   10.6402,
            4.7685,  -46.1619,  -26.7579,
          -48.4387,   44.1966,  -18.2875,
          -35.5813],
        [ -61.6316,  -25.6416,  -28.7592,
           -9.5010,  -17.8787,  -49.3770,
          -42.4279,  -63.5348,   49.9255,
           -8.2657],
        [  22.2697,   10.1308,  -40.9131,
          -34.5510,   57.2394,  -27.1202,
           51.7749,  -96.0782,    4.0343,
          -18.7105],
        [  35.4451,  -54.7640,  -29.4009,
          -65.0773,  -31.3074,  -52.9765,
          -50.7534,  -67.7673,  -21.6395,
          -36.2219],
        [ -55.1902,  -43.5276,   22.0906,
           39.8346,  -98.3593,    7.3984,
          -71.0787,  -63.1775,  -34.8546,
           39.4297],
        [ -35.8509,  -39.0902,  -16.9043,
            9.0145,  -62.7864,  -29.8620,
          -87.4322,  -56.2781,   -4.5117,
           41.9364],
        [ -84.0603,   36.1035,   19.3447,
           18.1973,  -30.6207,   -2.5074,
           14.9145,  -69.6632,   14.7989,
          -32.9170],
        [ -49.1929,  -27.0290,    4.6255,
           35.0274,  -23.3463,    2.5409,
          -82.4045,  -52.3015,  -40.4974,
           -6.6615],
        [ -28.6515,   11.9169,   -4.7532,
          -28.1880,  -36.2217,  -11.6573,
           56.3938,  -92.3155,  -19.4293,
           -3.6325],
        [ -10.2365,  -25.4885,    4.6966,
          -18.4677,  -39.1529,  -28.5339,
           58.8753,  -94.7391,   -7.8719,
          -44.5395],
        [  36.5956,  -54.2408,   -7.9828,
          -31.4007,  -60.0725,  -73.9617,
          -35.5274,  -63.9784,  -17.4701,
            7.2594],
        [ -67.8348,  -22.2471,  -58.2029,
          -15.3862,   -9.1726,   21.4612,
          -76.4956,  -22.6289,  -33.2844,
           -5.2027],
        [  21.9947,    6.6304,  -18.5833,
          -49.1579,   60.0684,    2.7041,
           38.9698,  -64.2615,  -37.0968,
          -29.3544],
        ...,
        [ -69.5320,  -53.5644,   -8.4442,
           33.8669,  -81.9889,  -26.2148,
          -97.0356,  -77.3523,  -11.2585,
          -11.7880],
        [  34.6750,   -3.9794,   -4.4232,
          -22.8697,   60.7744,   -8.3808,
           59.5475,  -42.6261,   14.3972,
          -10.2616],
        [ -73.9245,   41.7595,  -13.9660,
            7.3793,   16.7320,  -26.5570,
          -18.6693,  -25.9286,  -40.8057,
          -15.1597],
        [ -15.9094,   -5.9849,  -37.7155,
          -57.8141,    6.3571,  -33.8558,
            4.4501,  -29.1332,   -5.2438,
          -37.9425],
        [ -59.3682,   22.4388,  -16.4493,
           -1.9291,   -2.9836,  -14.8061,
          -27.6597,  -15.6523,  -23.3290,
          -14.6019],
        [ -10.3065,  -71.9873,  -37.2596,
          -22.8238,  -56.7525,  -52.5827,
          -46.2107,  -56.2524,   32.0420,
            2.7163],
        [-125.3412,  -10.9036,  -10.1358,
           -9.8955,  -58.0807,   43.1364,
          -85.5127,  -35.4018,  -17.0024,
          -24.9838],
        [  35.8807,  -85.3185,  -14.1419,
          -43.3398,  -47.6494,  -69.5635,
          -65.0889,  -59.0621,  -24.1110,
          -31.8553],
        [ -10.3379,   -4.6593,  -49.8481,
          -19.2881,   26.5490,   13.7519,
          -15.5973,   -1.1546,   10.3668,
           26.9829],
        [ -95.5513,  -25.2283,   71.5070,
           31.3183,  -65.6321,  -11.7375,
          -28.2158,  -59.8576,  -30.5740,
          -10.6293],
        [   3.9714,   -6.1320,   -7.7150,
            4.3892,   22.1828,  -46.4714,
            2.4544,   49.2627,   12.4349,
          -18.1459],
        [ -68.5562,  -29.6500,   63.9699,
           -0.8962,  -94.6507,  -47.4809,
          -62.7365,   -2.4928,  -27.2724,
          -32.5723],
        [ -48.9494,  -38.6800,   -0.8626,
           21.3021,  -80.3762,   -5.0741,
          -59.4682,  -64.6661,  -37.2080,
           17.0452],
        [-116.5415,  -22.2232,  -18.6835,
          -30.3401,  -65.8282,   32.3907,
          -85.3017,   -3.1637,  -11.1661,
          -37.6613],
        [ -53.8954,  -25.8811,   45.8346,
          -12.6007,  -61.4846,  -27.2438,
          -26.4994,    1.7727,  -27.5884,
          -54.5431],
        [ -41.2342,   25.1260,   -3.1794,
           -3.6952,  -64.1741,  -64.1356,
           51.8241,  -93.6669,  -34.1719,
           -4.2749],
        [ -87.3485,  -43.3562,  -18.9921,
           40.6294,  -86.1238,  -16.1066,
          -91.6288,  -69.5523,  -20.2742,
           -7.9495],
        [   5.7155,  -20.4136,  -36.3922,
          -29.2255,   29.6297,   -0.1888,
           -7.9959,  -14.7098,  -18.6376,
           14.8159],
        [ -72.0315,   28.6389,  -15.9866,
           -7.4088,   -0.2316,  -16.4018,
          -32.0956,  -18.0283,  -33.0930,
          -18.9966],
        [ -54.8775,  -12.1453,  -46.1965,
           -8.0844,  -40.4212,   41.8590,
          -73.7968,  -64.2573,  -32.1981,
           -5.0725],
        [  39.1486,  -71.1878,  -17.3247,
          -49.7138,  -55.8501,  -67.3489,
          -38.8621,  -69.4966,  -11.3626,
          -16.8447],
        [-117.1165,  -24.3999,   -6.8177,
          -15.4861,  -54.6202,   16.8432,
          -90.2593,   -0.5775,   18.2690,
          -21.4833],
        [   9.7740,    8.8800,  -66.5547,
          -18.5843,   40.1320,  -30.6314,
          -34.9514,    1.5537,  -20.1621,
           14.6515],
        [   8.0999,  -32.5154,  -13.4643,
          -34.1753,  -23.4296,  -18.6327,
           45.5812,  -92.9919,   -5.7646,
          -38.0658],
        [ -65.0105,  -41.7965,  -27.3865,
           16.2514,  -82.4736,  -16.6111,
         -106.3225,  -96.5700,  -40.6017,
          -18.6682],
        [ -59.5109,  -34.2753,   22.7640,
          -33.7454, -117.3810,  -28.2189,
          -60.5967,  -39.5485,  -37.7502,
          -58.8680],
        [ -99.0598,  -18.3286,  -30.5382,
           25.6998,  -48.3091,   44.6432,
          -79.6358,  -61.8885,   -2.1916,
            4.0065],
        [  39.6762,  -17.6546,  -43.0146,
          -61.7079,  -18.2184,  -53.2627,
          -35.7196,  -60.7951,  -10.3877,
           -9.1515],
        [ -37.2355,    0.4509,  -34.1900,
          -39.9283,   12.1733,  -33.8624,
          -58.8478,   34.9122,   -3.4259,
           13.3382],
        [ -80.9881,  -52.0928,  -18.7864,
           21.6604,  -77.4637,  -20.1207,
          -81.6271,  -83.8124,  -16.2177,
          -27.5814]],
       grad_fn=<SqueezeBackward0>)
val, index = torch.max(model(X_test), dim=1)
index
tensor([7, 9, 7, 6, 0, 2, 4, 3, 6, 3, 2, 8, 7,
        9, 4, 3, 8, 7, 8, 4, 0, 3, 9, 1, 3, 6,
        6, 0, 5, 4, 1, 0, 1, 2, 3, 8, 7, 6, 4,
        8, 6, 4, 4, 0, 9, 7, 8, 5, 4, 4, 4, 1,
        7, 6, 8, 2, 9, 8, 8, 0, 8, 3, 1, 8, 8,
        8, 3, 9, 1, 3, 9, 6, 9, 5, 6, 1, 9, 2,
        1, 3, 8, 7, 3, 3, 8, 3, 7, 5, 8, 2, 6,
        1, 9, 1, 6, 4, 5, 2, 2, 4, 5, 6, 7, 6,
        5, 9, 2, 4, 1, 0, 7, 6, 1, 2, 9, 5, 2,
        5, 0, 3, 2, 7, 6, 4, 9, 2, 1, 1, 6, 9,
        6, 6, 7, 4, 7, 5, 0, 9, 1, 0, 5, 6, 7,
        8, 3, 8, 3, 2, 0, 4, 6, 3, 5, 4, 6, 1,
        1, 1, 6, 1, 7, 0, 0, 7, 9, 5, 6, 1, 3,
        8, 6, 4, 7, 1, 5, 7, 4, 7, 4, 3, 2, 2,
        1, 8, 4, 4, 3, 5, 5, 9, 4, 5, 5, 9, 3,
        9, 2, 1, 2, 0, 8, 2, 8, 9, 2, 4, 6, 8,
        3, 8, 1, 0, 8, 1, 8, 5, 6, 8, 8, 1, 8,
        0, 4, 9, 7, 0, 5, 5, 6, 1, 3, 0, 5, 8,
        2, 0, 9, 8, 6, 7, 8, 4, 1, 0, 5, 2, 5,
        1, 6, 4, 7, 1, 2, 6, 4, 4, 6, 3, 2, 1,
        2, 6, 5, 2, 9, 4, 7, 0, 1, 0, 4, 3, 1,
        2, 7, 9, 8, 5, 9, 5, 7, 0, 0, 8, 4, 9,
        4, 0, 7, 7, 2, 5, 3, 5, 9, 8, 7, 9, 8,
        2, 7, 6, 3, 9, 1, 7, 9, 8, 5, 0, 2, 0,
        2, 7, 0, 9, 5, 5, 3, 6, 1, 2, 3, 9, 1,
        3, 2, 9, 3, 4, 3, 4, 1, 4, 1, 8, 5, 0,
        9, 2, 7, 2, 3, 5, 2, 6, 3, 4, 1, 5, 0,
        8, 4, 6, 3, 2, 5, 0, 7, 3])
(index == y_test).sum()
tensor(318)
(index == y_test).sum() / len(y_test)
tensor(0.8833)

Calculating Accuracy

class mnist_model(torch.nn.Module):
    """Multinomial logistic regression for the digits data; this version
    also records train/test accuracy every `acc_step` iterations."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # nn.Parameter sets requires_grad=True itself, so passing
        # requires_grad=True into randn was redundant.
        self.beta = torch.nn.Parameter(
            torch.randn(input_dim, output_dim)
        )
        self.intercept = torch.nn.Parameter(
            torch.randn(output_dim)
        )

    def forward(self, X):
        # Class scores (logits); CrossEntropyLoss applies softmax itself.
        return (X @ self.beta + self.intercept).squeeze()

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
        """Fit with SGD + momentum; return (losses, train_acc, test_acc).

        Accuracies are appended every `acc_step` iterations as 0-dim
        tensors (unchanged from the original behavior).
        """
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        loss_fn = torch.nn.CrossEntropyLoss()  # hoisted out of the loop
        losses, train_acc, test_acc = [], [], []

        for i in range(n):
            opt.zero_grad()
            loss = loss_fn(self(X_train), y_train)
            loss.backward()
            opt.step()
            losses.append(loss.item())

            if (i+1) % acc_step == 0:
                # Accuracy is evaluation only -- skip building the graph.
                with torch.no_grad():
                    val, train_pred = torch.max(self(X_train), dim=1)
                    val, test_pred  = torch.max(self(X_test), dim=1)

                    train_acc.append( (train_pred == y_train).sum() / len(y_train) )
                    test_acc.append( (test_pred == y_test).sum() / len(y_test) )

        return (losses, train_acc, test_acc)

Performance

loss, train_acc, test_acc = mnist_model(
  64, 10
).fit(
  X_train, y_train, X_test, y_test, acc_step=10, n=3000
)

NN Layers

class mnist_nn_model(torch.nn.Module):
    """Multinomial regression on the digits data, this time letting
    torch.nn.Linear own the weight matrix and bias."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, X):
        # Per-class logits; softmax is handled inside the loss.
        return self.linear(X)

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
        """Train with SGD + momentum; return (losses, train_acc, test_acc)."""
        loss_fn = torch.nn.CrossEntropyLoss()
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        losses, train_acc, test_acc = [], [], []

        for step in range(1, n + 1):
            opt.zero_grad()
            loss = loss_fn(self(X_train), y_train)
            loss.backward()
            opt.step()
            losses.append(loss.item())

            if step % acc_step == 0:
                # predicted class = index of the largest logit
                train_pred = torch.argmax(self(X_train), dim=1)
                test_pred = torch.argmax(self(X_test), dim=1)

                train_acc.append((train_pred == y_train).sum() / len(y_train))
                test_acc.append((test_pred == y_test).sum() / len(y_test))

        return (losses, train_acc, test_acc)

NN linear layer

Applies a linear transform to the incoming data (\(X\)): \[y = X A^T+b\]

X.shape
(1797, 64)
model = mnist_nn_model(64, 10)
model.parameters()
<generator object Module.parameters at 0x329a919a0>
list(model.parameters())[0].shape  # A - weights (betas)
torch.Size([10, 64])
list(model.parameters())[1].shape  # b - bias
torch.Size([10])

Performance

loss, train_acc, test_acc = model.fit(X_train, y_train, X_test, y_test, n=1000)
train_acc[-5:]
[tensor(0.9916), tensor(0.9916), tensor(0.9916), tensor(0.9916), tensor(0.9916)]
test_acc[-5:]
[tensor(0.9694), tensor(0.9694), tensor(0.9667), tensor(0.9667), tensor(0.9667)]

Feedforward Neural Network

FNN Model

class mnist_fnn_model(torch.nn.Module):
    """Single-hidden-layer feedforward network for the digits data.

    `nl_step` is the nonlinearity applied after the hidden layer
    (default ReLU; the shared default instance is harmless since these
    activation modules are stateless).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, nl_step = torch.nn.ReLU(), seed=1234):
        super().__init__()
        # NOTE(review): `seed` is accepted but never used -- layer
        # initialization still depends on global RNG state. Kept only for
        # interface compatibility; wire it up or drop it.
        self.l1 = torch.nn.Linear(input_dim, hidden_dim)
        self.nl = nl_step
        self.l2 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, X):
        # input -> hidden -> nonlinearity -> output logits
        out = self.l1(X)
        out = self.nl(out)
        out = self.l2(out)
        return out

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
        """Train with SGD + momentum; return (losses, train_acc, test_acc).

        Accuracies are recorded every `acc_step` iterations as plain
        Python floats.
        """
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        loss_fn = torch.nn.CrossEntropyLoss()  # hoisted out of the loop
        losses, train_acc, test_acc = [], [], []

        for i in range(n):
            opt.zero_grad()
            loss = loss_fn(self(X_train), y_train)
            loss.backward()
            opt.step()

            losses.append(loss.item())

            if (i+1) % acc_step == 0:
                # Accuracy is evaluation only -- skip building the graph.
                with torch.no_grad():
                    val, train_pred = torch.max(self(X_train), dim=1)
                    val, test_pred  = torch.max(self(X_test), dim=1)

                    train_acc.append( (train_pred == y_train).sum().item() / len(y_train) )
                    test_acc.append( (test_pred == y_test).sum().item() / len(y_test) )

        return (losses, train_acc, test_acc)

Non-linear activation functions

\[\text{Tanh}(x) = \frac{\exp(x)-\exp(-x)}{\exp(x) + \exp(-x)}\]

\[\text{ReLU}(x) = \max(0,x)\]

Model parameters

model = mnist_fnn_model(64,64,10)
len(list(model.parameters()))
4
for i, p in enumerate(model.parameters()):
  print("Param", i, p.shape)
Param 0 torch.Size([64, 64])
Param 1 torch.Size([64])
Param 2 torch.Size([10, 64])
Param 3 torch.Size([10])

Performance - ReLU

loss, train_acc, test_acc = mnist_fnn_model(64,64,10).fit(
  X_train, y_train, X_test, y_test, n=2000
)
train_acc[-5:]
[0.9979123173277662, 0.9979123173277662, 0.9979123173277662, 0.9979123173277662, 0.9979123173277662]
test_acc[-5:]
[0.9722222222222222, 0.9722222222222222, 0.9722222222222222, 0.9722222222222222, 0.9722222222222222]

Performance - tanh

loss, train_acc, test_acc = mnist_fnn_model(64,64,10, nl_step=torch.nn.Tanh()).fit(
  X_train, y_train, X_test, y_test, n=2000
)
train_acc[-5:]
[0.9958246346555324, 0.9958246346555324, 0.9958246346555324, 0.9958246346555324, 0.9958246346555324]
test_acc[-5:]
[0.9722222222222222, 0.9722222222222222, 0.9722222222222222, 0.9722222222222222, 0.9722222222222222]

Adding another layer

class mnist_fnn2_model(torch.nn.Module):
    """Two-hidden-layer feedforward network for the digits data.

    The same `nl_step` module is applied after both hidden layers
    (fine, since the activation modules are stateless).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, nl_step = torch.nn.ReLU(), seed=1234):
        super().__init__()
        # NOTE(review): `seed` is accepted but never used; kept only for
        # interface compatibility.
        self.l1 = torch.nn.Linear(input_dim, hidden_dim)
        # The original assigned self.nl = nl_step twice (before and after
        # l2) -- the duplicate assignment was redundant and is removed.
        self.nl = nl_step
        self.l2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.l3 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, X):
        # input -> hidden -> nl -> hidden -> nl -> output logits
        out = self.l1(X)
        out = self.nl(out)
        out = self.l2(out)
        out = self.nl(out)
        out = self.l3(out)
        return out

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
        """Train with SGD + momentum; return (losses, train_acc, test_acc).

        Accuracies are recorded every `acc_step` iterations as plain
        Python floats.
        """
        loss_fn = torch.nn.CrossEntropyLoss()
        opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        losses, train_acc, test_acc = [], [], []

        for i in range(n):
            opt.zero_grad()
            loss = loss_fn(self(X_train), y_train)
            loss.backward()
            opt.step()

            losses.append(loss.item())

            if (i+1) % acc_step == 0:
                # Accuracy is evaluation only -- skip building the graph.
                with torch.no_grad():
                    val, train_pred = torch.max(self(X_train), dim=1)
                    val, test_pred  = torch.max(self(X_test), dim=1)

                    train_acc.append( (train_pred == y_train).sum().item() / len(y_train) )
                    test_acc.append( (test_pred == y_test).sum().item() / len(y_test) )

        return (losses, train_acc, test_acc)

Performance - relu

loss, train_acc, test_acc = mnist_fnn2_model(
  64,64,10, nl_step=torch.nn.ReLU()
).fit(
  X_train, y_train, X_test, y_test, n=1000
)
train_acc[-5:]
[0.9895615866388309, 0.9902574808629089, 0.9902574808629089, 0.9909533750869868, 0.9909533750869868]
test_acc[-5:]
[0.9611111111111111, 0.9611111111111111, 0.9611111111111111, 0.9611111111111111, 0.9611111111111111]

Performance - tanh

loss, train_acc, test_acc = mnist_fnn2_model(
  64,64,10, nl_step=torch.nn.Tanh()
).fit(
  X_train, y_train, X_test, y_test, n=1000
)
train_acc[-5:]
[0.9791231732776617, 0.9798190675017397, 0.9798190675017397, 0.9805149617258176, 0.9805149617258176]
test_acc[-5:]
[0.9638888888888889, 0.9638888888888889, 0.9638888888888889, 0.9638888888888889, 0.9666666666666667]

Convolutional NN

2d convolutions

nn.Conv2d()

cv = torch.nn.Conv2d(
  in_channels=1, out_channels=4, 
  kernel_size=3, 
  stride=1, padding=1
)
list(cv.parameters())[0] # kernel weights
Parameter containing:
tensor([[[[-0.1000, -0.0723, -0.2855],
          [-0.2065, -0.1656,  0.1223],
          [-0.2908, -0.2739, -0.1053]]],

        [[[ 0.3038, -0.0362, -0.0239],
          [ 0.1094, -0.0125,  0.0823],
          [-0.0237,  0.1522,  0.1868]]],

        [[[ 0.1054, -0.0330,  0.0633],
          [-0.1794,  0.1278,  0.0690],
          [-0.0593,  0.2729,  0.1282]]],

        [[[-0.3325, -0.0735, -0.0929],
          [-0.3116, -0.0260, -0.1559],
          [ 0.1824, -0.2539,  0.0196]]]],
       requires_grad=True)
list(cv.parameters())[1] # biases
Parameter containing:
tensor([-0.2050,  0.0266,  0.3102,  0.2498],
       requires_grad=True)

Applying Conv2d()

X_train[[0]]
tensor([[ 0.,  0.,  0., 10., 11.,  0.,  0.,
          0.,  0.,  0.,  9., 16.,  6.,  0.,
          0.,  0.,  0.,  0., 15., 13.,  0.,
          0.,  0.,  0.,  0.,  0., 14., 10.,
          0.,  0.,  0.,  0.,  0.,  1., 15.,
         12.,  8.,  2.,  0.,  0.,  0.,  0.,
         12., 16., 16., 16., 10.,  1.,  0.,
          0.,  7., 16., 12., 12., 16.,  4.,
          0.,  0.,  0.,  9., 15., 12.,  5.,
          0.]])
X_train[[0]].shape
torch.Size([1, 64])
cv(X_train[[0]])
RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 64]
X_train[[0]].view(1,8,8)
tensor([[[ 0.,  0.,  0., 10., 11.,  0.,  0.,
           0.],
         [ 0.,  0.,  9., 16.,  6.,  0.,  0.,
           0.],
         [ 0.,  0., 15., 13.,  0.,  0.,  0.,
           0.],
         [ 0.,  0., 14., 10.,  0.,  0.,  0.,
           0.],
         [ 0.,  1., 15., 12.,  8.,  2.,  0.,
           0.],
         [ 0.,  0., 12., 16., 16., 16., 10.,
           1.],
         [ 0.,  0.,  7., 16., 12., 12., 16.,
           4.],
         [ 0.,  0.,  0.,  9., 15., 12.,  5.,
           0.]]])
cv(X_train[[0]].view(1,8,8))
tensor([[[-2.0501e-01, -1.1529e+00,
          -3.1323e+00, -8.1468e+00,
          -1.0387e+01, -4.2215e+00,
          -2.0501e-01, -2.0501e-01],
         [-2.0501e-01, -6.8433e-01,
          -8.0716e+00, -1.5765e+01,
          -1.0078e+01, -2.5439e+00,
          -2.0501e-01, -2.0501e-01],
         [-2.0501e-01, -2.4153e+00,
          -1.1206e+01, -1.6035e+01,
          -7.8311e+00, -8.0488e-01,
          -2.0501e-01, -2.0501e-01],
         [-3.1033e-01, -4.6301e+00,
          -1.1759e+01, -1.5682e+01,
          -9.4612e+00, -3.0791e+00,
          -7.8660e-01, -2.0501e-01],
         [-8.2744e-02, -3.7980e+00,
          -1.0267e+01, -1.5991e+01,
          -1.5483e+01, -1.2276e+01,
          -8.1149e+00, -3.3868e+00],
         [-4.9056e-01, -3.8305e+00,
          -8.4485e+00, -1.5709e+01,
          -1.5754e+01, -1.4341e+01,
          -1.3536e+01, -8.1840e+00],
         [-2.0501e-01, -2.7757e+00,
          -5.7917e+00, -1.3802e+01,
          -1.9343e+01, -1.6500e+01,
          -1.2310e+01, -6.6976e+00],
         [-2.0501e-01, -2.2039e+00,
          -4.1794e+00, -5.1440e+00,
          -8.9735e+00, -1.1314e+01,
          -7.0096e+00, -3.1264e+00]],

        [[ 2.6574e-02,  1.7082e+00,
           5.2086e+00,  4.1493e+00,
           1.5172e+00,  1.0881e+00,
           2.6574e-02,  2.6574e-02],
         [ 2.6574e-02,  3.5696e+00,
           5.7036e+00,  2.3036e+00,
           4.0340e+00,  4.0246e+00,
           2.6574e-02,  2.6574e-02],
         [ 2.6574e-02,  3.6613e+00,
           4.2002e+00,  4.7076e+00,
           5.8552e+00,  1.8491e+00,
           2.6574e-02,  2.6574e-02],
         [ 2.1342e-01,  3.7748e+00,
           4.3228e+00,  8.4855e+00,
           6.3766e+00,  1.4114e-01,
          -2.0892e-02,  2.6574e-02],
         [ 1.0882e-01,  3.1557e+00,
           5.0065e+00,  1.1208e+01,
           9.4874e+00,  4.8014e+00,
           1.5747e+00, -5.8542e-02],
         [ 2.6906e-03,  1.9271e+00,
           4.7223e+00,  1.0899e+01,
           9.8909e+00,  9.2895e+00,
           5.2407e+00,  1.3377e+00],
         [ 2.6574e-02,  3.1573e-01,
           2.1205e+00,  8.4367e+00,
           1.0826e+01,  8.9534e+00,
           6.4199e+00,  4.6106e+00],
         [ 2.6574e-02, -1.4061e-01,
           1.3146e-01,  2.4088e+00,
           5.9507e+00,  4.7586e+00,
           4.2483e+00,  5.2893e+00]],

        [[ 3.1021e-01,  1.4639e+00,
           5.5077e+00,  6.9485e+00,
           6.0934e-01, -2.0192e+00,
           3.1021e-01,  3.1021e-01],
         [ 3.1021e-01,  2.8545e+00,
           8.9574e+00,  4.1778e+00,
          -1.8735e+00,  3.9341e-01,
           3.1021e-01,  3.1021e-01],
         [ 3.1021e-01,  3.7101e+00,
           8.9423e+00,  1.9791e+00,
          -1.1265e+00,  9.4271e-01,
           3.1021e-01,  3.1021e-01],
         [ 4.3840e-01,  4.4218e+00,
           8.6897e+00,  3.6389e+00,
           1.6139e+00,  3.8120e-01,
           1.9152e-01,  3.1021e-01],
         [ 3.7926e-01,  3.8979e+00,
           8.3727e+00,  6.5559e+00,
           5.8394e+00,  3.8290e+00,
           1.8589e+00, -1.0376e-02],
         [ 3.7350e-01,  2.9524e+00,
           7.2797e+00,  8.4873e+00,
           5.5800e+00,  5.5654e+00,
           3.1641e+00, -1.2140e+00],
         [ 3.1021e-01,  1.5529e+00,
           4.0799e+00,  8.0560e+00,
           7.0707e+00,  5.6128e+00,
           2.5503e+00, -1.3246e+00],
         [ 3.1021e-01,  7.5318e-01,
           1.7134e+00,  3.4657e+00,
           3.4911e+00,  1.3795e+00,
          -2.1306e-01,  9.6804e-01]],

        [[ 2.4980e-01,  4.2660e-01,
          -3.2802e+00, -4.0289e+00,
          -1.7578e+00, -2.0832e+00,
           2.4980e-01,  2.4980e-01],
         [ 2.4980e-01, -8.5860e-01,
          -6.9616e+00, -6.2290e+00,
          -6.6547e+00, -5.2774e+00,
           2.4980e-01,  2.4980e-01],
         [ 2.4980e-01, -2.6499e+00,
          -7.6744e+00, -9.4748e+00,
          -7.7384e+00, -1.7453e+00,
           2.4980e-01,  2.4980e-01],
         [ 2.6944e-01, -3.2857e+00,
          -7.3753e+00, -1.0471e+01,
          -6.9922e+00,  1.2010e+00,
           6.1457e-01,  2.4980e-01],
         [ 9.3900e-02, -3.1798e+00,
          -7.0145e+00, -1.2934e+01,
          -8.1648e+00, -3.2431e+00,
           2.5189e-02,  1.8197e+00],
         [ 1.5688e-01, -2.9507e+00,
          -6.5706e+00, -1.5564e+01,
          -1.2304e+01, -1.0062e+01,
          -7.6123e+00, -9.8952e-01],
         [ 2.4980e-01, -1.9565e+00,
          -4.6190e+00, -1.2863e+01,
          -1.6834e+01, -1.3935e+01,
          -9.7589e+00, -7.3264e+00],
         [ 2.4980e-01, -4.0061e-01,
          -3.1546e+00, -6.9421e+00,
          -1.2133e+01, -1.1875e+01,
          -9.1577e+00, -6.9226e+00]]],
       grad_fn=<SqueezeBackward1>)

Pooling

x = torch.tensor(
  [[[0,0,0,0],
    [0,1,2,0],
    [0,3,4,0],
    [0,0,0,0]]],
  dtype=torch.float
)
x.shape
torch.Size([1, 4, 4])
torch.nn.MaxPool2d(
  kernel_size=2, stride=1
)(x)
tensor([[[1., 2., 2.],
         [3., 4., 4.],
         [3., 4., 4.]]])
torch.nn.MaxPool2d(
  kernel_size=3, stride=1, padding=1
)(x)
tensor([[[1., 2., 2., 2.],
         [3., 4., 4., 4.],
         [3., 4., 4., 4.],
         [3., 4., 4., 4.]]])
torch.nn.AvgPool2d(
  kernel_size=2
)(x)
tensor([[[0.2500, 0.5000],
         [0.7500, 1.0000]]])
torch.nn.AvgPool2d(
  kernel_size=2, padding=1
)(x)
tensor([[[0.0000, 0.0000, 0.0000],
         [0.0000, 2.5000, 0.0000],
         [0.0000, 0.0000, 0.0000]]])

Convolutional model

class mnist_conv_model(torch.nn.Module):
    """Small convolutional classifier for the 8x8 digits:
    conv -> ReLU -> max pool -> linear."""

    def __init__(self):
        super().__init__()
        self.cnn  = torch.nn.Conv2d(
          in_channels=1, out_channels=8,
          kernel_size=3, stride=1, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.pool = torch.nn.MaxPool2d(kernel_size=2)
        self.lin  = torch.nn.Linear(8 * 4 * 4, 10)

    def forward(self, X):
        # Reshape the flat 64-pixel rows into (batch, 1, 8, 8) images,
        # run conv/ReLU/pool, then flatten for the final linear layer.
        imgs = X.view(-1, 1, 8, 8)
        feats = self.pool(self.relu(self.cnn(imgs)))
        return self.lin(feats.view(-1, 8 * 4 * 4))

    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
        """Train with SGD + momentum; return (losses, train_acc, test_acc)."""
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        losses, train_acc, test_acc = [], [], []

        for step in range(1, n + 1):
            optimizer.zero_grad()
            loss = criterion(self(X_train), y_train)
            loss.backward()
            optimizer.step()

            losses.append(loss.item())

            if step % acc_step == 0:
                # predicted class = index of the largest logit
                train_pred = torch.argmax(self(X_train), dim=1)
                test_pred = torch.argmax(self(X_test), dim=1)

                train_acc.append((train_pred == y_train).sum().item() / len(y_train))
                test_acc.append((test_pred == y_test).sum().item() / len(y_test))

        return (losses, train_acc, test_acc)

Performance

loss, train_acc, test_acc = mnist_conv_model().fit(
  X_train, y_train, X_test, y_test, n=1000
)
train_acc[-5:]
[0.9944328462073765, 0.9944328462073765, 0.9944328462073765, 0.9951287404314544, 0.9951287404314544]
test_acc[-5:]
[0.9833333333333333, 0.9833333333333333, 0.9833333333333333, 0.9833333333333333, 0.9833333333333333]

Organizing models

class mnist_conv_model2(torch.nn.Module):
    """CNN for the 8x8 digits expressed as a single torch.nn.Sequential pipeline."""
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
          torch.nn.Unflatten(1, (1,8,8)),
          torch.nn.Conv2d(
            in_channels=1, out_channels=8,
            kernel_size=3, stride=1, padding=1
          ),
          torch.nn.ReLU(),
          torch.nn.MaxPool2d(kernel_size=2),
          torch.nn.Flatten(),
          torch.nn.Linear(8 * 4 * 4, 10)
        )
        
    def forward(self, X):
        return self.model(X)
    
    def fit(self, X_train, y_train, X_test, y_test, lr=0.001, n=1000, acc_step=10):
      """Train with SGD + momentum on cross-entropy loss.

      Returns (losses, train_acc, test_acc); accuracies are recorded every
      `acc_step` iterations as Python floats.
      """
      # Create the loss module once instead of on every iteration
      loss_fn = torch.nn.CrossEntropyLoss()
      opt = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9) 
      losses, train_acc, test_acc = [], [], []
      
      for i in range(n):
          opt.zero_grad()
          loss = loss_fn(self(X_train), y_train)
          loss.backward()
          opt.step()
          
          losses.append(loss.item())
          
          if (i+1) % acc_step == 0:
            # Evaluation only -- disable autograd so no graph is built
            with torch.no_grad():
              train_pred = torch.argmax(self(X_train), dim=1)
              test_pred  = torch.argmax(self(X_test), dim=1)
            
            # .item() converts the 0-dim tensor to a float (matches mnist_conv_model)
            train_acc.append( (train_pred == y_train).sum().item() / len(y_train) )
            test_acc.append( (test_pred == y_test).sum().item() / len(y_test) )
            
      return (losses, train_acc, test_acc)

A bit more on non-linear activation layers

Non-linear functions

df = pd.read_csv("data/gp.csv")
X = torch.tensor(df["x"], dtype=torch.float32).reshape(-1,1)
y = torch.tensor(df["y"], dtype=torch.float32)

Linear regression

class lin_reg(torch.nn.Module):
    """Linear regression expressed as a one-layer neural network."""
    def __init__(self, X):
        super().__init__()
        n_obs, n_feat = X.shape
        self.n = n_obs
        self.p = n_feat
        self.model = torch.nn.Sequential(
          torch.nn.Linear(self.p, self.p)
        )

    def forward(self, X):
        return self.model(X)

    def fit(self, X, y, n=1000):
        """Run n steps of SGD (lr=0.001, momentum=0.9) on MSE; return per-step losses."""
        opt = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        mse = torch.nn.MSELoss()
        history = []
        for _ in range(n):
            step_loss = mse(self(X).squeeze(), y)
            step_loss.backward()
            opt.step()
            opt.zero_grad()
            history.append(step_loss.item())
        return history

Model results

m1 = lin_reg(X)
loss = m1.fit(X,y, n=2000)

Training loss:

Predictions

Double linear regression

class dbl_lin_reg(torch.nn.Module):
    """Two stacked linear layers with no activation between them."""
    def __init__(self, X, hidden_dim=10):
        super().__init__()
        self.n = X.shape[0]
        self.p = X.shape[1]
        self.model = torch.nn.Sequential(
            torch.nn.Linear(self.p, hidden_dim),
            torch.nn.Linear(hidden_dim, 1),
        )

    def forward(self, X):
        return self.model(X)

    def fit(self, X, y, n=1000):
        """Take n SGD steps minimizing MSE; return the loss recorded at each step."""
        opt = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        criterion = torch.nn.MSELoss()
        history = []
        for _ in range(n):
            fit_loss = criterion(self(X).squeeze(), y)
            fit_loss.backward()
            opt.step()
            opt.zero_grad()
            history.append(fit_loss.item())
        return history

Model results

m2 = dbl_lin_reg(X, hidden_dim=10)
loss = m2.fit(X,y, n=2000)

Training loss:

Predictions

Non-linear regression w/ ReLU

class lin_reg_relu(torch.nn.Module):
    """One-hidden-layer network with a ReLU activation for nonlinear regression."""
    def __init__(self, X, hidden_dim=100):
        super().__init__()
        n_obs, n_feat = X.shape
        self.n = n_obs
        self.p = n_feat
        self.model = torch.nn.Sequential(
            torch.nn.Linear(n_feat, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1),
        )

    def forward(self, X):
        return self.model(X)

    def fit(self, X, y, n=1000):
        """SGD (lr=0.001, momentum=0.9) on MSE for n iterations; returns the losses."""
        sgd = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        mse = torch.nn.MSELoss()
        trace = []
        for _ in range(n):
            err = mse(self(X).squeeze(), y)
            err.backward()
            sgd.step()
            sgd.zero_grad()
            trace.append(err.item())
        return trace

Model results

Hidden dimensions

Non-linear regression w/ Tanh

class lin_reg_tanh(torch.nn.Module):
    """One-hidden-layer network with a Tanh activation for nonlinear regression."""
    def __init__(self, X, hidden_dim=10):
        super().__init__()
        self.n, self.p = X.shape[0], X.shape[1]
        self.model = torch.nn.Sequential(
            torch.nn.Linear(self.p, hidden_dim),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden_dim, 1),
        )

    def forward(self, X):
        return self.model(X)

    def fit(self, X, y, n=1000):
        """Minimize MSE with SGD (lr=0.001, momentum=0.9); return losses per step."""
        optimizer = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        loss_fn = torch.nn.MSELoss()
        loss_trace = []
        for _ in range(n):
            current = loss_fn(self(X).squeeze(), y)
            current.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_trace.append(current.item())
        return loss_trace

Tanh & hidden dimension

Three layers

class three_layers(torch.nn.Module):
    """Three linear layers with ReLU activations between them."""
    def __init__(self, X, hidden_dim=100):
        super().__init__()
        self.n = X.shape[0]
        self.p = X.shape[1]
        # Layers are assembled in order and then handed to Sequential
        layers = [
            torch.nn.Linear(self.p, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1),
        ]
        self.model = torch.nn.Sequential(*layers)

    def forward(self, X):
        return self.model(X)

    def fit(self, X, y, n=1000):
        """Run n SGD steps (lr=0.001, momentum=0.9) on MSE; return per-step losses."""
        opt = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        mse = torch.nn.MSELoss()
        recorded = []
        for _ in range(n):
            batch_loss = mse(self(X).squeeze(), y)
            batch_loss.backward()
            opt.step()
            opt.zero_grad()
            recorded.append(batch_loss.item())
        return recorded

Model results

Five layers

class five_layers(torch.nn.Module):
    """Five linear layers with ReLU activations between them."""
    def __init__(self, X, hidden_dim=100):
        super().__init__()
        self.n = X.shape[0]
        self.p = X.shape[1]
        # Build layers in order: input layer, three hidden blocks, output layer
        layers = [torch.nn.Linear(self.p, hidden_dim), torch.nn.ReLU()]
        for _ in range(3):
            layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
            layers.append(torch.nn.ReLU())
        layers.append(torch.nn.Linear(hidden_dim, 1))
        self.model = torch.nn.Sequential(*layers)

    def forward(self, X):
        return self.model(X)

    def fit(self, X, y, n=1000):
        """Run n SGD steps (lr=0.001, momentum=0.9) on MSE; return per-step losses."""
        opt = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        mse = torch.nn.MSELoss()
        recorded = []
        for _ in range(n):
            batch_loss = mse(self(X).squeeze(), y)
            batch_loss.backward()
            opt.step()
            opt.zero_grad()
            recorded.append(batch_loss.item())
        return recorded

Model results

:::